'''线性拉伸百分比'''
import numpy as np
from osgeo import gdal
import tqdm
# A PostGIS installation interferes with these settings, so the environment
# variables are set again here.
# NOTE(review): this runs after `from osgeo import gdal`; presumably GDAL/PROJ
# read these variables lazily, so the ordering still works — confirm.
import os
os.environ['CPL_ZIP_ENCODING'] = 'UTF-8'
# NOTE(review): the two paths below are absolute paths inside one specific
# conda environment on drive D: and must be adjusted on any other machine.
os.environ['PROJ_LIB'] = r'D:\anaconda3\envs\pytorch_GPU\Lib\site-packages\pyproj\proj_dir\share\proj'
os.environ['GDAL_DATA'] = r'D:\anaconda3\envs\pytorch_GPU\Library\share'

def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    """Read a raster with GDAL and return its pixels and georeferencing.

    Args:
        fileName: Path of the raster file passed to ``gdal.Open``.
        xoff: Column offset of the window to read (default 0).
        yoff: Row offset of the window to read (default 0).
        data_width: Width of the window in pixels. 0 means "from ``xoff`` to
            the right edge of the raster".
        data_height: Height of the window in pixels. 0 means "from ``yoff``
            to the bottom edge of the raster".

    Returns:
        A tuple ``(data, width, height, bands, geotrans, proj)``:
        ``data`` is whatever ``Dataset.ReadAsArray`` returns for the window,
        ``width``/``height`` are the full raster size (``RasterXSize`` /
        ``RasterYSize``, not the window size), ``bands`` is ``RasterCount``,
        ``geotrans`` is the geotransform of the whole dataset and ``proj``
        its projection string.

    Raises:
        RuntimeError: If GDAL returns no dataset for ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with a readable message instead of an AttributeError on
        # the first attribute access below.
        raise RuntimeError(fileName + "文件无法打开")

    # Number of columns / rows / bands of the raster.
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    bands = dataset.RasterCount

    # A zero window size means "everything that is left after the offset".
    # With all-default arguments this is exactly the whole raster.
    if data_width == 0:
        data_width = width - xoff
    if data_height == 0:
        data_height = height - yoff
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)

    # Affine geotransform and projection of the dataset.
    geotrans = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    return data, width, height, bands, geotrans, proj


def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write an array to a GeoTIFF file.

    Args:
        im_data: NumPy array shaped ``(bands, height, width)`` or
            ``(height, width)``; a 2-D array is written as a single band.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    # Pick the GDAL pixel type from the array dtype. 8-bit integers are stored
    # as Byte, 16-bit integers keep their signedness, everything else is
    # written as Float32.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Promote a single band to (1, height, width).
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim != 3:
        raise ValueError(
            f"im_data must be 2-D or 3-D, got shape {im_data.shape}")
    im_bands, im_height, im_width = im_data.shape

    # Create the output file with the pixel type chosen above.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create {path}")
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection

    print(f'波段总和{im_bands}')
    for i in tqdm.tqdm(range(im_bands)):
        # GDAL band indices are 1-based.
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Flush and drop the reference so GDAL closes the file.
    dataset.FlushCache()
    del dataset


def truncated_linear_stretch(image, truncated_value, max_out=255, min_out=0):
    """Linearly stretch each band between two symmetric percentiles.

    Args:
        image: NumPy array. A 3-D array is processed band by band along
            axis 0; anything else is processed as one band.
        truncated_value: Percentage cut off at each end; the stretch maps the
            ``truncated_value`` percentile to ``min_out`` and the
            ``100 - truncated_value`` percentile to ``max_out``.
        max_out: Upper bound of the output range.
        min_out: Lower bound of the output range.

    Returns:
        The stretched array: ``uint8`` when ``max_out <= 255``, ``uint16``
        when ``max_out <= 65535``, otherwise left as floating point.
    """
    def gray_process(gray1):
        band = np.asarray(gray1)
        if band.dtype.kind in "biu":
            # Integer bands (unsigned ones in particular) cannot hold the
            # negative sentinel below; promote them first. The percentile and
            # the stretch are computed in floating point anyway.
            band = band.astype(np.float64)

        # Non-positive pixels are replaced by the sentinel -9999.
        # NOTE(review): the sentinel still takes part in the percentile
        # computation below — confirm whether such pixels were meant to be
        # excluded instead.
        gray = np.where(band > 0, band, -9999)

        truncated_down = np.percentile(gray, truncated_value)
        truncated_up = np.percentile(gray, 100 - truncated_value)
        span = truncated_up - truncated_down
        if span == 0:
            # Flat band: nothing to stretch. Emit the lower bound instead of
            # dividing by zero and casting NaN to an integer type.
            gray = np.full(gray.shape, min_out, dtype=np.float64)
        else:
            gray = (gray - truncated_down) / span * (max_out - min_out) + min_out
            gray = np.clip(gray, min_out, max_out)

        if max_out <= 255:
            gray = np.uint8(gray)
        elif max_out <= 65535:
            gray = np.uint16(gray)
        return gray

    # Multi-band input: stretch every band independently.
    if len(image.shape) == 3:
        return np.array([gray_process(image[i]) for i in range(image.shape[0])])
    # Single-band input.
    return gray_process(image)
def compress(array_data, rows, cols, bands, low_per_raw=0.2, high_per_raw=0.2,
             mean_values=None, std_value=None):
    """Clip each band at two percentiles and rescale it to an 8-bit array.

    For every band the share of zero-valued pixels is measured, the band is
    clipped to ``[cutmin, cutmax]`` and then normalised as
    ``(value - mean) / std * 255``, rounded and saturated to 0..255.

    Args:
        array_data: NumPy array indexed as ``[band, row, col]``. It is not
            modified.
        rows: Number of rows of one band.
        cols: Number of columns of one band.
        bands: Number of bands to process.
        low_per_raw: Lower cut as a fraction (0..1). It is shifted by the
            share of zero pixels, so it is a fraction of the non-zero pixels
            as long as zero is the smallest value in the band.
        high_per_raw: Upper cut as a fraction (0..1) of all pixels, e.g.
            0.98 selects the 98th percentile.
            NOTE(review): the default 0.2 therefore selects the 20th
            percentile as the upper cut — confirm the intended default.
        mean_values: Per-band means used by the normalisation; defaults to
            ``[18.05, 24.82, 22.33]``.
        std_value: Standard deviation used by the normalisation; defaults to
            ``19.238129``.

    Returns:
        ``uint8`` array of shape ``(bands, rows, cols)``.

    Raises:
        IndexError: If ``bands`` exceeds the number of entries in
            ``mean_values``.
    """
    if mean_values is None:
        mean_values = [18.05, 24.82, 22.33]
    if std_value is None:
        std_value = 19.238129
    if bands > len(mean_values):
        # Fail before doing any work rather than on the first band that has
        # no mean value.
        raise IndexError(
            f"{bands} bands requested but only {len(mean_values)} mean values available")

    print("1shape:", array_data.shape)

    # The dtype here fixes the bit depth of the output.
    compress_data = np.zeros((bands, rows, cols), dtype="uint8")
    for i in range(bands):
        band = array_data[i, :, :]

        # Share of zero pixels (the black border) in this band.
        num0 = np.sum(band == 0)
        kk = num0 / (rows * cols)
        print(f'0的占比为{kk}')

        # Shift the lower cut past the zero pixels so that it applies to the
        # remaining pixels only: X = kk + low_per_raw * (1 - kk).
        low_per = (low_per_raw + kk - low_per_raw * kk) * 100
        # Upper cut as a percentage of all pixels.
        high_per = 100 - (1 - high_per_raw) * 100
        print("baifen:", low_per, high_per)

        # Pixel values at the two cut points.
        cutmin = np.percentile(band, low_per)
        cutmax = np.percentile(band, high_per)
        print("duandian:", cutmin, cutmax)

        # Clip on a copy so the caller's array stays untouched.
        clipped = band.copy()
        clipped[clipped < cutmin] = cutmin
        clipped[clipped > cutmax] = cutmax

        # Normalise, then saturate to the uint8 range; casting out-of-range
        # floats to uint8 directly would wrap around.
        scaled = np.around((clipped - mean_values[i]) / std_value * 255)
        compress_data[i, :, :] = np.clip(scaled, 0, 255)

    print("maxANDmin：", np.max(compress_data), np.min(compress_data))
    return compress_data
if __name__ == '__main__':
    # Input raster and the path the 8-bit result is written to.
    src_path = r"G:\GF-2test_result\GF2_PMS1_E110.6_N21.2_20210223_L1A0005501538_pansharpen.tif"
    dst_path = r"G:\GF-2test_result\GF2_PMS1_E110.6_N21.21538_lin0_zscore1.tif"

    raster = readTif(src_path)
    pixels, n_cols, n_rows, n_bands = raster[:4]
    geo_transform, projection = raster[4:]

    # compress() expects rows (height) before columns (width).
    # Alternative stretch, currently disabled:
    #   truncated_linear_stretch(pixels, 0.25, max_out=255, min_out=0)
    stretched = compress(pixels, n_rows, n_cols, n_bands, low_per_raw=0.02, high_per_raw=0.98)
    writeTiff(stretched, geo_transform, projection, dst_path)